# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
"""Common utility functions"""

from typing import Callable, Tuple, Any
import functools
import itertools
from packaging import version

import torch


def _is_expandable(src_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> bool:
    """
    Check whether ``src_shape`` can be expanded (broadcast) to ``target_shape``.

    The shapes are aligned at their trailing dimension. Every source dimension
    has to be either 1 or equal to the target dimension it is aligned with, and
    the source must not have more dimensions than the target.

    :param src_shape: Shape to expand from
    :param target_shape: Shape to expand to
    :return: True if the expansion is possible; False otherwise
    """
    if len(src_shape) > len(target_shape):
        return False

    # Walk both shapes from the last dimension towards the first
    aligned_dims = zip(reversed(src_shape), reversed(target_shape))
    return all(src_dim in (1, dst_dim) for src_dim, dst_dim in aligned_dims)


def _is_reducible(src_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> bool:
    """
    Check whether ``src_shape`` can be reduced to ``target_shape``.

    Reduction is the inverse of expansion: the source is reducible to the
    target exactly when the target is expandable to the source.

    :param src_shape: Shape to reduce from
    :param target_shape: Shape to reduce to
    :return: True if the reduction is possible; False otherwise
    """
    # The roles are swapped on purpose; keywords make the swap explicit
    return _is_expandable(src_shape=target_shape, target_shape=src_shape)


def reduce(input: torch.Tensor, shape: Tuple[int, ...], reduce_op: Callable):
    """
    Reduce ``input`` to the given shape by applying ``reduce_op`` over every
    axis whose size in the (left-padded) target shape is 1.

    :param input: Input to reduce
    :param shape: Shape of the reduced output
    :param reduce_op: Reduce operation; invoked as ``reduce_op(tensor, dim=0, keepdim=False)``
    :return: The value returned by ``reduce_op``
    :raises RuntimeError: If the shape of ``input`` can't be reduced to ``shape``
    """
    if not _is_reducible(input.shape, shape):
        raise RuntimeError(
            f"Input of shape {list(input.shape)} can't be reduced to shape {list(shape)}"
        )

    # Left-pad the target shape with 1s so it has as many axes as the input
    num_missing_axes = len(input.shape) - len(shape)
    padded_shape = (1,) * num_missing_axes + tuple(shape)

    axes_to_reduce = []
    axes_to_keep = []
    for axis, dim in enumerate(padded_shape):
        if dim == 1:
            axes_to_reduce.append(axis)
        elif dim > 1:
            axes_to_keep.append(axis)

    # Bring the axes to reduce to the front and fold them into one leading
    # axis, so that a single reduction along dim 0 handles all of them
    folded = input.permute((*axes_to_reduce, *axes_to_keep)).reshape(-1, *shape)
    return reduce_op(folded, dim=0, keepdim=False)


class _ContextManager:
    """
    Context manager assembled from two callbacks.

    ``action`` is called on entering the ``with`` block and ``cleanup`` is
    called on leaving it, whether or not the block raised. An instance can
    also be applied as a decorator, in which case every call of the decorated
    function runs inside the context.
    """

    def __init__(self, action: Callable[[], Any], cleanup: Callable[[], Any]):
        """
        :param action: Zero-argument callable invoked by ``__enter__``
        :param cleanup: Zero-argument callable invoked by ``__exit__``
        """
        self._action = action
        self._cleanup = cleanup

    def __enter__(self):
        self._action()
        return self

    def __exit__(self, *exc_info):
        # The exception details are ignored; returning None lets any
        # exception raised inside the block propagate
        self._cleanup()

    def __call__(self, fn: Callable):
        """
        Decorate ``fn`` so that each call is executed inside this context.
        """

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)

        return wrapped


class _NullAttribute:
    """
    Sentinel type. ``patch_attr`` stores an instance of it in place of the
    original value when the attribute being patched did not exist beforehand.
    """


def patch_attr(obj, attr_name, new_attr) -> _ContextManager:
    """
    Temporarily overwrite object attribute

    If ``obj`` is a ``torch.nn.Module`` and ``attr_name`` names one of its
    parameters or buffers, the work is delegated to ``_patch_param_or_buffer``.
    Otherwise the attribute is set with ``setattr`` on entering the returned
    context and put back on leaving it.

    The original value is captured when this function is called, not when the
    context is entered.

    :param obj: Object whose attribute is overwritten
    :param attr_name: Name of the attribute to overwrite
    :param new_attr: Value the attribute holds inside the context
    :return: Context manager (also usable as a decorator) that applies the
        patch on entry and reverts it on exit
    """
    if isinstance(obj, torch.nn.Module):
        if attr_name in obj._parameters or attr_name in obj._buffers:  # pylint: disable=protected-access
            return _patch_param_or_buffer(obj, attr_name, new_attr)

    if hasattr(obj, attr_name):
        old_attr = getattr(obj, attr_name)
    else:
        old_attr = _NullAttribute()

    # Remember whether the original value was stored on ``obj`` itself rather
    # than being inherited (e.g. from its class). Without this, an own value
    # that shadows an inherited attribute of the same name would be lost on
    # cleanup: after ``delattr`` the inherited attribute becomes visible again
    # and looks as if nothing had to be restored.
    own_attrs = getattr(obj, "__dict__", None)
    was_own_attr = own_attrs is not None and attr_name in own_attrs

    def action():
        setattr(obj, attr_name, new_attr)

    def cleanup():
        try:
            delattr(obj, attr_name)
        except AttributeError:
            # Best effort: there is nothing deletable under this name on obj
            pass

        if isinstance(old_attr, _NullAttribute):
            # The attribute did not exist before patching; deleting it is enough
            return

        # Restore when the original lived on ``obj`` itself, or when deleting
        # the patched value left no attribute behind. If an inherited attribute
        # is what remains and the original was not an own attribute, the
        # original state is already back and nothing is written to ``obj``.
        if was_own_attr or not hasattr(obj, attr_name):
            setattr(obj, attr_name, old_attr)

    return _ContextManager(action, cleanup)


def _patch_param_or_buffer(
    module: torch.nn.Module,
    param_or_buffer_name: str,
    new_param_or_buffer: torch.Tensor,
):
    """
    Temporarily substitute the reference to a parameter or buffer with a new tensor.
    Under the scope of the returned context, ``getattr(module, param_or_buffer_name)``
    will return ``new_param_or_buffer`` instead of the original parameter or buffer.

    The original value is captured when this function is called, not when the
    context is entered.

    :param module: Module that owns the parameter or buffer
    :param param_or_buffer_name: Name of the parameter or buffer
    :param new_param_or_buffer: New tensor to replace the original parameter or buffer
    :return: Context manager that installs the new tensor on entry and puts
        the original back on exit
    :raises RuntimeError: If the name is found in neither ``module._parameters``,
        ``module._buffers`` nor ``module.__dict__``
    """
    # pylint: disable=protected-access

    orig_param_or_buffer = getattr(module, param_or_buffer_name)
    if orig_param_or_buffer is not None:
        # The replacement must have the same shape as the tensor it stands in for
        assert new_param_or_buffer.shape == orig_param_or_buffer.shape

    if param_or_buffer_name in module._parameters:
        container = module._parameters
    elif param_or_buffer_name in module._buffers:
        container = module._buffers
    elif param_or_buffer_name in module.__dict__:
        # Some non-standard modules (e.g. replicas of torch.nn.DataParallel)
        # store their parameters directly in __dict__
        container = module.__dict__
    else:
        raise RuntimeError(
            f"'{param_or_buffer_name}' is not a valid name of parameter or buffer of {type(module)}."
        )

    # Write straight into the owning dict so that the swap bypasses
    # torch.nn.Module.__setattr__
    action = lambda: container.update({param_or_buffer_name: new_param_or_buffer})
    cleanup = lambda: container.update({param_or_buffer_name: orig_param_or_buffer})

    return _ContextManager(action, cleanup)


class _StraightThroughEstimator(torch.autograd.Function):  # pylint: disable=abstract-method
    """
    Autograd function that evaluates ``op`` in the forward pass and hands the
    incoming gradients back unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, op, *args, **kwargs):  # pylint:disable=arguments-differ, unused-argument
        output = op(*args, **kwargs)
        return output

    @staticmethod
    def backward(ctx, *grad_outputs):
        # The leading None is the gradient slot of ``op``, which is not
        # differentiable; the remaining slots receive the gradients as-is
        return (None,) + grad_outputs


def ste_round(*args, **kwargs):
    """
    Round with ``torch.round`` in the forward pass while passing gradients
    straight through in the backward pass.

    All arguments are forwarded to ``torch.round`` via
    ``_StraightThroughEstimator``.
    """
    straight_through = _StraightThroughEstimator.apply
    return straight_through(torch.round, *args, **kwargs)


class StatisticsNotFoundError(RuntimeError):
    """
    Error raised when compute_encodings() is invoked without statistics.

    Subclasses ``RuntimeError``, so handlers written for ``RuntimeError``
    catch it as well.
    """


# Module-wide switch read by ``is_recompute_enabled``; only modified through
# the context managers created by ``_set_enable_recompute``
_ENABLE_RECOMPUTE = False


def _set_enable_recompute(mode: bool):
    """
    Create a context manager that sets the recompute flag to ``mode`` on entry
    and restores the previous value on exit.

    The value to restore is the one observed when the context is *entered*.
    Capturing it when the context object is created instead would let a
    context built ahead of time (for example a decorator applied at import
    time) overwrite the flag of an enclosing scope on exit.

    :param mode: Value of the recompute flag inside the context
    :return: Context manager, also usable as a decorator
    """
    # Value to restore if cleanup is ever invoked without a preceding entry
    fallback_mode = _ENABLE_RECOMPUTE
    # One saved value per active entry, so that re-entering the same context
    # object (e.g. a decorated recursive function) unwinds correctly
    saved_modes = []

    def action():
        global _ENABLE_RECOMPUTE  # pylint: disable=global-statement
        saved_modes.append(_ENABLE_RECOMPUTE)
        _ENABLE_RECOMPUTE = mode

    def cleanup():
        global _ENABLE_RECOMPUTE  # pylint: disable=global-statement
        _ENABLE_RECOMPUTE = saved_modes.pop() if saved_modes else fallback_mode

    return _ContextManager(action, cleanup)


def is_recompute_enabled():
    """
    Returns True if recomputation for memory saving is enabled; False otherwise.

    :return: Current value of the module-level ``_ENABLE_RECOMPUTE`` flag
    """
    return _ENABLE_RECOMPUTE


def enable_recompute():
    """
    Turn on recomputation for memory saving.

    :return: Context manager (also usable as a decorator) inside which
        ``is_recompute_enabled()`` returns True
    """
    return _set_enable_recompute(mode=True)


def no_recompute():
    """
    Turn off recomputation for memory saving.

    :return: Context manager (also usable as a decorator) inside which
        ``is_recompute_enabled()`` returns False
    """
    return _set_enable_recompute(mode=False)


def allow_recompute(fn):
    """
    Decorator that lets the activations of ``fn`` be recomputed during
    training whenever recompute is enabled.

    :param fn: Function to wrap
    :return: Wrapper that calls ``fn`` directly when recompute is disabled and
        routes the call through ``torch.utils.checkpoint.checkpoint`` otherwise
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not is_recompute_enabled():
            return fn(*args, **kwargs)

        # Activation recompute (a.k.a. activation checkpointing) reduces the
        # memory footprint of training
        return torch.utils.checkpoint.checkpoint(
            fn, *args, use_reentrant=False, **kwargs
        )

    return wrapper


def flatten_nn_module_list(module):
    """
    Flatten an arbitrarily nested list/tuple/``torch.nn.ModuleList`` of
    modules into a flat list, preserving depth-first, left-to-right order.

    :param module: A single object or a nested list/tuple/ModuleList
    :return: Flat list of the leaf objects; an input that is not a
        list/tuple/ModuleList yields a one-element list
    """
    container_types = (list, tuple, torch.nn.ModuleList)
    flattened = []

    def collect(node):
        if not isinstance(node, container_types):
            flattened.append(node)
            return
        for child in node:
            collect(child)

    collect(module)
    return flattened


def _map_qmodule(modules, func):
    """
    Call ``func`` on every ``BaseQuantizationMixin`` found among the given
    modules and their submodules, and bundle the returned contexts.

    ``func`` is invoked immediately, not when the returned context is entered.
    If any call raises, the contexts collected so far are cleaned up before
    the exception is re-raised.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :param func: Callable taking a quantized module and returning an object
        with a ``_cleanup()`` method
    :return: Context manager whose entry does nothing and whose exit calls
        ``_cleanup()`` on every collected context
    """
    # pylint: disable=import-outside-toplevel
    # pylint: disable=protected-access, cyclic-import
    from aimet_torch.v2.nn import BaseQuantizationMixin

    applied_contexts = []

    def undo_all():
        return [applied._cleanup() for applied in applied_contexts]

    ctx = _ContextManager(action=lambda: None, cleanup=undo_all)

    roots = [modules] if isinstance(modules, torch.nn.Module) else modules

    try:
        for root in roots:
            for submodule in root.modules():
                if not isinstance(submodule, BaseQuantizationMixin):
                    continue
                applied_contexts.append(func(submodule))
    except Exception:
        # Revert whatever was already applied before propagating the error
        ctx._cleanup()
        raise

    return ctx


def remove_input_quantizers(modules):
    """
    Temporarily remove all input quantizers

    ``_remove_input_quantizers()`` is called on every quantized module as soon
    as this function is called; leaving the returned context runs the cleanup
    of each of those calls.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :return: Context manager that undoes the removal on exit

    Example:

        >>> print(sim.model)
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
        >>> with remove_input_quantizers(sim.model):
        ...     print(sim.model)
        ...
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): None
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
    """

    # pylint: disable=protected-access
    def strip_inputs(qmodule):
        return qmodule._remove_input_quantizers()

    return _map_qmodule(modules, strip_inputs)


def remove_output_quantizers(modules):
    """
    Temporarily remove all output quantizers

    ``_remove_output_quantizers()`` is called on every quantized module as
    soon as this function is called; leaving the returned context runs the
    cleanup of each of those calls.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :return: Context manager that undoes the removal on exit

    Example:

        >>> print(sim.model)
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
        >>> with remove_output_quantizers(sim.model):
        ...     print(sim.model)
        ...
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): None
            )
          )
        )
    """

    # pylint: disable=protected-access
    def strip_outputs(qmodule):
        return qmodule._remove_output_quantizers()

    return _map_qmodule(modules, strip_outputs)


def remove_param_quantizers(modules):
    """
    Temporarily remove all parameter quantizers

    ``_remove_param_quantizers()`` is called on every quantized module as soon
    as this function is called; leaving the returned context runs the cleanup
    of each of those calls.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :return: Context manager that undoes the removal on exit

    Example:

        >>> print(sim.model)
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
        >>> with remove_param_quantizers(sim.model):
        ...     print(sim.model)
        ...
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): None
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
    """

    # pylint: disable=protected-access
    def strip_params(qmodule):
        return qmodule._remove_param_quantizers()

    return _map_qmodule(modules, strip_params)


def remove_activation_quantizers(modules):
    """
    Temporarily remove all input and output quantizers

    The quantizers are removed as soon as this function is called; leaving the
    returned context undoes the removal. If removing the output quantizers
    fails, the input-quantizer removal already performed is rolled back before
    the exception propagates.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :return: Context manager that undoes both removals on exit

    Example:

        >>> print(sim.model)
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
        >>> with remove_activation_quantizers(sim.model):
        ...     print(sim.model)
        ...
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): None
            )
            (output_quantizers): ModuleList(
              (0): None
            )
          )
        )
    """
    if not isinstance(modules, torch.nn.Module):
        # Shallow copy in case modules is an iterator
        modules = list(modules)

    # pylint: disable=protected-access
    context_1 = remove_input_quantizers(modules)
    try:
        context_2 = remove_output_quantizers(modules)
    except Exception:
        # Don't leave the modules half-modified when the second step fails
        context_1._cleanup()
        raise

    def cleanup():
        # Run the second cleanup even if the first one raises
        try:
            context_1._cleanup()
        finally:
            context_2._cleanup()

    return _ContextManager(action=lambda: None, cleanup=cleanup)


def remove_all_quantizers(modules):
    """
    Temporarily remove all quantizers

    The quantizers are removed as soon as this function is called; leaving the
    returned context undoes the removal. If removing the parameter quantizers
    fails, the activation-quantizer removal already performed is rolled back
    before the exception propagates.

    :param modules: A ``torch.nn.Module`` or an iterable of modules
    :return: Context manager that undoes all removals on exit

    Example:

        >>> print(sim.model)
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
            (output_quantizers): ModuleList(
              (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
            )
          )
        )
        >>> with remove_all_quantizers(sim.model):
        ...     print(sim.model)
        ...
        Sequential(
          (0): QuantizedConv2d(
            3, 3, kernel_size=(3, 3), stride=(1, 1)
            (param_quantizers): ModuleDict(
              (weight): None
              (bias): None
            )
            (input_quantizers): ModuleList(
              (0): None
            )
            (output_quantizers): ModuleList(
              (0): None
            )
          )
        )
    """
    if not isinstance(modules, torch.nn.Module):
        # Shallow copy in case modules is an iterator
        modules = list(modules)

    # pylint: disable=protected-access
    context_1 = remove_activation_quantizers(modules)
    try:
        context_2 = remove_param_quantizers(modules)
    except Exception:
        # Don't leave the modules half-modified when the second step fails
        context_1._cleanup()
        raise

    def cleanup():
        # Run the second cleanup even if the first one raises
        try:
            context_1._cleanup()
        finally:
            context_2._cleanup()

    return _ContextManager(action=lambda: None, cleanup=cleanup)


def has_no_quantizers(module, ignore_params: bool = False) -> bool:
    """
    Tell whether every quantizer slot of ``module`` is ``None``.

    :param module: Object exposing ``input_quantizers`` and ``output_quantizers``
        (iterables) and ``param_quantizers`` (a mapping)
    :param ignore_params: If True, the parameter quantizers are not inspected
    :return: True if all input and output quantizers, and unless
        ``ignore_params`` is set all parameter quantizers, are ``None``
    """
    if any(inp_qtzr is not None for inp_qtzr in module.input_quantizers):
        return False

    if any(out_qtzr is not None for out_qtzr in module.output_quantizers):
        return False

    return ignore_params or all(
        param_qtzr is None for param_qtzr in module.param_quantizers.values()
    )


def rgetattr(obj, attr):
    """
    ``getattr`` variant that accepts dotted attribute paths such as ``"a.b.c"``.

    :param obj: Object to start the lookup from
    :param attr: Attribute name, optionally dotted
    :return: The attribute reached by following each path component in turn
    """
    target = obj
    for name in attr.split("."):
        target = getattr(target, name)
    return target


def rsetattr(obj, attr, val):
    """
    ``setattr`` variant that accepts dotted attribute paths such as ``"a.b.c"``.

    :param obj: Object to start the lookup from
    :param attr: Attribute name, optionally dotted; everything before the last
        dot is resolved with ``rgetattr``
    :param val: Value to assign
    :return: The result of ``setattr`` (i.e. None)
    """
    parent_path, _, leaf_name = attr.rpartition(".")
    if parent_path:
        owner = rgetattr(obj, parent_path)
    else:
        owner = obj
    return setattr(owner, leaf_name, val)


def apply_fn_recursively_to_all_elems(fn, container):
    """
    Apply ``fn`` to every leaf of a container composed recursively of lists,
    tuples and dicts.

    :param fn: Callable applied to each leaf element
    :param container: None, a leaf, or a nested list/tuple/dict
    :return: None for None; a list for list and tuple inputs; a dict with the
        same keys for dict inputs; ``fn(container)`` for anything else
    """
    if container is None:
        return None

    if isinstance(container, (list, tuple)):
        # Note that tuples come back as lists
        mapped_elems = []
        for elem in container:
            mapped_elems.append(apply_fn_recursively_to_all_elems(fn, elem))
        return mapped_elems

    if isinstance(container, dict):
        mapped_items = {}
        for key, elem in container.items():
            mapped_items[key] = apply_fn_recursively_to_all_elems(fn, elem)
        return mapped_items

    return fn(container)


def flatten_list(container):
    """
    Helper function to flatten nested list/tuple into 1D

    Lists and tuples may be mixed freely at any nesting level. The elements
    are returned in depth-first, left-to-right order.

    :param container: Nested list/tuple, or a single object
    :return: A flat list if ``container`` is a list, a flat tuple if it is a
        tuple. An empty list/tuple is returned as-is. An object that is
        neither a list nor a tuple is returned as-is when falsy, and wrapped
        in a one-element list otherwise.
    """
    if not isinstance(container, (list, tuple)):
        if not container:
            return container
        return [container]

    if not container:
        return container

    # Iterative depth-first walk: a stack of iterators replaces recursion, so
    # neither the length nor the nesting depth is bounded by the interpreter's
    # recursion limit, and no intermediate slices are created
    flat = []
    pending = [iter(container)]
    while pending:
        for elem in pending[-1]:
            if isinstance(elem, (list, tuple)):
                # Descend; the parent iterator keeps its position on the stack
                pending.append(iter(elem))
                break
            flat.append(elem)
        else:
            # Innermost iterator exhausted; resume its parent
            pending.pop()

    # Keep the type of the outermost container
    return tuple(flat) if isinstance(container, tuple) else flat


def default_forward_fn(model, inputs):
    """
    Default forward function.

    A single tensor is passed to the model as its only positional argument;
    any other value is unpacked into positional arguments.

    :param model: pytorch model
    :param inputs: model inputs
    :return: Whatever the model returns
    """
    if isinstance(inputs, torch.Tensor):
        return model(inputs)
    return model(*inputs)


# Version-independent handles for the torch.compiler state queries.
# The names are declared up front so that both branches below bind them
# with the same signature.
_torch_compiler_is_compiling: Callable[[], bool]
_torch_compiler_is_dynamo_compiling: Callable[[], bool]
_torch_compiler_is_exporting: Callable[[], bool]

if version.parse(torch.__version__) >= version.parse("2.7"):
    # Bind directly to the functions provided by torch
    _torch_compiler_is_compiling = torch.compiler.is_compiling
    _torch_compiler_is_dynamo_compiling = torch.compiler.is_dynamo_compiling
    _torch_compiler_is_exporting = torch.compiler.is_exporting
else:
    # torch < 2.7.0 doesn't have torch.compiler.is_compiling/exporting API
    # The fallbacks below unconditionally report "not compiling / not exporting"
    def _torch_compiler_is_compiling() -> bool:
        return False

    def _torch_compiler_is_dynamo_compiling() -> bool:
        return False

    def _torch_compiler_is_exporting() -> bool:
        return False
